Don't unnecessarily wrap the elem in PythonTensor #554
Conversation
Instead of saying that a PythonTensor has a regular (e.g., CPU) tensor and an FX proxy, a PythonTensor *is a* regular CPU tensor that also carries an FX proxy (that updates as we go along). This should fix #465, and it also fixes some expected failures in the test suite. Signed-off-by: Edward Z. Yang <[email protected]>
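For readers skimming the diff, here is a minimal sketch of the "is-a" pattern described above. The names (ProxyTensor, proxy) are hypothetical and this is not the actual functorch implementation; it only illustrates how torch.Tensor._make_subclass makes the subclass be the underlying CPU tensor rather than wrap it.

```python
import torch

class ProxyTensor(torch.Tensor):
    # Sketch of the "is-a" pattern: the subclass IS the underlying CPU
    # tensor (same storage, same device, same metadata) and merely carries
    # an extra attribute for the FX proxy, instead of holding an inner
    # tensor as a separate wrapped field.
    @staticmethod
    def __new__(cls, elem, proxy):
        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        r.proxy = proxy  # updated as each traced op produces a new proxy
        return r

t = ProxyTensor(torch.randn(3), proxy=None)
print(isinstance(t, torch.Tensor), t.device, t.requires_grad)  # True cpu False
```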
I’ll review the rest in a bit, but sadly this doesn’t fix #465 - that’s also a problem for vmap, not just AOTAutograd.
More or less fine with this - I feel like this is just sidestepping the real problem though, which is that we have some issues with wrapper tensor subclasses.
functorch/_src/python_key.py
# TODO: this might not actually work, I didn't test it when
# I changed device derivation to work off of the types of the
# input devices
args = pytree.tree_map(lambda x: torch.ones_like(x, device=x.device)
This doesn't work, since at this point the args will be meta tensors, and the device will simply be the meta device.
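A quick illustration of the point (not code from the PR): once the inputs have been swapped for meta tensors, their device attribute is simply the meta device, so deriving a device from them says nothing about the real output device.

```python
import torch

# By this point in the trace the inputs are meta tensors, so any device
# read off of them is just "meta", and ones_like reproduces a meta tensor
# rather than a real CPU/CUDA one.
x = torch.empty(4, device="meta")
print(x.device)                                      # meta
print(torch.ones_like(x, device=x.device).device)    # meta
```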
I'm not super opposed to ripping out the meta tensor code entirely though, and reimplementing it some different way if you think there's a better way to do it :P
Yeah, let's dump the meta tensor code for now and I'll reimplement it shortly
# PythonTensor boundary.
# assert not elem.requires_grad or not torch.is_grad_enabled()

r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
This doesn't work for meta tensors since at this point, elem will be a meta tensor. So we're just gonna make a PythonTensor with the meta device anyways.
That's why I went through all of the shenanigans of inferring the output device - if we run with meta tensors, then at no point do we have the actual output device of the operator. All you have is the device of the input tensors.
So... ripping out the device inference logic will make the meta-tracing stuff not work at all, in which case we should just remove all of it :P
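For context, a hedged sketch of what "inferring the output device" from the inputs could look like. This is a guess at the general idea, not the PR's actual device-inference code; infer_output_device and its heuristic are assumptions for illustration.

```python
import torch

def infer_output_device(input_devices):
    # Assumed heuristic: if any input lived on an accelerator, assume the
    # output does too; otherwise fall back to CPU. Under meta tracing this
    # recorded input-device information is all that is left to go on.
    for d in input_devices:
        if d.type not in ("cpu", "meta"):
            return d
    return torch.device("cpu")

print(infer_output_device([torch.device("cpu"), torch.device("cuda", 0)]))  # cuda:0
```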
This seems like a bigger structural problem for meta tensors. Will need to think about this...
…h/functorch#554)

* Don't unnecessarily wrap the elem in PythonTensor

Instead of saying that a PythonTensor has a regular (e.g., CPU) tensor and an FX proxy, a PythonTensor *is a* regular CPU tensor, that also carries an FX proxy (that updates as we go along). This should fix pytorch/functorch#465 and it also fixed some expected failures in the test suite. This kills the meta variant logic entirely; maybe some other time we'll try to bring it back.

Signed-off-by: Edward Z. Yang <[email protected]>
Instead of saying that a PythonTensor has a regular (e.g., CPU) tensor
and an FX proxy, a PythonTensor is a regular CPU tensor that also
carries an FX proxy (that updates as we go along).
Partially addresses #465, and it also fixes some expected failures in the test suite.